import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad

from .pointnet_base import PointNetBase

class PointNetClassifier(nn.Module):
	"""PointNet feature extractor plus a fully-connected classification head.

	Args:
		num_points: forwarded unchanged as the first argument of ``PointNetBase``.
		K: forwarded unchanged as the second argument of ``PointNetBase``.
		num_classes: output width of the last ``nn.Linear`` in
			``self.classifier`` (default 40, the previously hard-coded value).
		dropout: probability given to each ``nn.Dropout`` in the head.
			``nn.Dropout(p)`` zeroes each activation with probability ``p``,
			so the default 0.7 drops 70% of activations.
			NOTE(review): if a *keep* ratio of 0.7 was intended (as described
			in the PointNet paper), pass ``dropout=0.3``.
	"""

	def __init__(self, num_points=2000, K=3, num_classes=40, dropout=0.7):
		super(PointNetClassifier, self).__init__()
		self.base = PointNetBase(num_points, K)
		# MLP head: 1024 -> 512 -> 256 -> num_classes. The 1024 input width is
		# presumably the size of the global feature returned by self.base —
		# TODO confirm against PointNetBase.
		self.classifier = nn.Sequential(
			nn.Linear(1024, 512),
			nn.BatchNorm1d(512),
			nn.ReLU(),
			nn.Dropout(dropout),
			nn.Linear(512, 256),
			nn.BatchNorm1d(256),
			nn.ReLU(),
			nn.Dropout(dropout),
			nn.Linear(256, num_classes))

	def forward(self, x):
		"""Return per-point features: the global feature tiled next to local ones.

		Runs ``self.base(x)``, which yields ``(global_feature, local_embedding,
		<third output>)``. The global feature is given a trailing axis and is
		broadcast along the last axis of ``local_embedding``. The two are then
		concatenated along dim 1.

		Presumably ``global_feature`` is (B, C_g) and ``local_embedding`` is
		(B, C_l, N), giving a (B, C_g + C_l, N) result — TODO confirm against
		PointNetBase.

		NOTE(review): ``self.classifier`` is never applied here, so this method
		returns per-point features rather than class scores, and the third
		output of ``self.base`` is discarded. Confirm this is intended before
		relying on the class for classification.
		"""
		global_feature, local_embedding, _ = self.base(x)
		num_points = local_embedding.shape[-1]
		# expand() is a broadcast view (no copy); torch.cat below performs the
		# single materialization, where repeat() would have copied twice.
		global_feature = global_feature.unsqueeze(-1).expand(-1, -1, num_points)
		return torch.cat((global_feature, local_embedding), dim=1)
		
		

		
		